import re
import pickle

from math import log
# see https://zhuanlan.zhihu.com/p/245372320
class ChineseTokenizer(object):
    """Dictionary + HMM Chinese word segmenter (jieba-style).

    Pipeline (see https://zhuanlan.zhihu.com/p/245372320):
      1. Split input into segmentable chunks vs. punctuation/whitespace.
      2. Over each chunk, build a DAG of dictionary words and pick the
         maximum-probability path under a unigram model.
      3. Runs of single characters the dictionary could not join are
         re-segmented with a BMES hidden Markov model via Viterbi decoding.
    """

    # Chunks the tokenizer segments: CJK ideographs plus latin letters,
    # digits and a few in-word symbols (+, #, &, ., _, %, -).
    re_han_default = re.compile(r"([\u4E00-\u9FD5a-zA-Z0-9+#&\._%\-]+)", re.U)
    # Whitespace runs are yielded whole; everything else char-by-char.
    re_skip_default = re.compile(r"(\r\n|\s)", re.U)

    # Extracts runs of consecutive CJK ideographs (HMM input).
    re_han = re.compile(r"([\u4E00-\u9FD5]+)")
    # Splits runs of non-CJK text into number/word tokens.
    re_skip = re.compile(r"([a-zA-Z0-9\.]+(?:\.\d+)?%?)")

    # Log-probability floor standing in for "impossible".
    MIN_FLOAT = -3.14e100

    @staticmethod
    def get_prefix_dict(f_name):
        """Load a ``word freq [tag]`` dictionary file.

        Returns ``(lfreq, ltotal)``: ``lfreq`` maps every word AND every
        prefix of every word to its frequency (prefixes not listed in the
        file get 0, so DAG construction can stop as soon as no dictionary
        word continues a fragment); ``ltotal`` is the sum of all real word
        frequencies.

        Raises:
            ValueError: on a malformed dictionary line.
        """
        lfreq = {}
        ltotal = 0
        # The dictionary contains Chinese text: pin the encoding to UTF-8
        # instead of relying on the platform default, and use a context
        # manager so the handle is closed even when a line is invalid.
        with open(f_name, encoding="utf-8") as f:
            for lineno, line in enumerate(f, 1):
                try:
                    line = line.strip()
                    word, freq = line.split(' ')[:2]
                    freq = int(freq)
                    lfreq[word] = freq
                    ltotal += freq
                    # Register every prefix with frequency 0 (if unseen) so
                    # get_DAG can detect "no word starts with this fragment".
                    for ch in range(len(word)):
                        wfrag = word[:ch + 1]
                        if wfrag not in lfreq:
                            lfreq[wfrag] = 0
                except ValueError:
                    raise ValueError(
                        'invalid dictionary entry in %s at Line %s: %s' % (f_name, lineno, line))
        return lfreq, ltotal

    def __init__(self, dict_path="./assets/dict.txt",
                 prob_start_path="./assets/prob_start.p",
                 prob_emit_path="./assets/prob_emit.p",
                 prob_trans_path="./assets/prob_trans.p"):
        """Load the prefix dictionary and the pretrained HMM parameters.

        All paths default to the original hard-coded asset locations, so
        existing ``ChineseTokenizer()`` callers are unaffected.

        NOTE: ``pickle.load`` must only be used on trusted files; these are
        local model assets shipped with the project.
        """
        self.freq, self.total = self.get_prefix_dict(dict_path)  # prefix dictionary
        with open(prob_start_path, "rb") as f:
            self.prob_start = pickle.load(f)  # initial state log-probabilities
        with open(prob_emit_path, "rb") as f:
            self.prob_emit = pickle.load(f)   # emission log-probabilities
        with open(prob_trans_path, "rb") as f:
            self.prob_trans = pickle.load(f)  # state-transition log-probabilities

    def cut(self, sentence):
        """Yield the tokens of *sentence* in order: segmented words,
        whitespace runs, and individual punctuation characters."""
        blocks = self.re_han_default.split(sentence)
        for blk in blocks:
            # re.split with a capturing group interleaves empty strings.
            if not blk:
                continue
            if self.re_han_default.match(blk):
                # Segmentable chunk: run the DAG/HMM pipeline.
                for word in self.cut_block(blk):
                    yield word
            else:
                # Punctuation / whitespace chunk.
                tmp = self.re_skip_default.split(blk)
                for x in tmp:
                    if self.re_skip_default.match(x):
                        # Whitespace (spaces, tabs, newlines) yielded whole.
                        yield x
                    else:
                        # Punctuation etc. is yielded character by character.
                        for xx in x:
                            yield xx

    def cut_block(self, sentence):
        """Segment one chunk that matched ``re_han_default``.

        Multi-character words on the best DAG path are yielded directly;
        runs of single characters are buffered and handed to the HMM,
        which may join them into words unseen by the dictionary.
        """
        DAG = self.get_DAG(sentence)
        route = self.calc(sentence, DAG)
        x = 0
        buf = ''
        N = len(sentence)
        while x < N:
            y = route[x][1] + 1
            l_word = sentence[x:y]

            # A single character joins the buffer awaiting HMM re-cutting.
            if y - x == 1:
                buf += l_word
            else:
                # A multi-character word ends the buffered run: flush it.
                if buf:
                    if len(buf) == 1:
                        # One buffered char: yield it as-is.
                        yield buf
                        buf = ''
                    else:
                        # If the whole buffer is itself a dictionary word the
                        # DAG already considered (and rejected) it, so emit
                        # per character; otherwise let the HMM re-segment.
                        if not self.freq.get(buf):
                            recognized = self.cut_regx_hmm(buf)
                            for t in recognized:
                                yield t
                        else:
                            for elem in buf:
                                yield elem
                        buf = ''
                yield l_word
            x = y

        # Flush any buffered characters left after the loop.
        if buf:
            if len(buf) == 1:
                yield buf
            elif not self.freq.get(buf):
                recognized = self.cut_regx_hmm(buf)
                for t in recognized:
                    yield t
            else:
                for elem in buf:
                    yield elem

    def cut_regx_hmm(self, sentence):
        """Route CJK runs to the HMM and split non-CJK runs with ``re_skip``."""
        blocks = self.re_han.split(sentence)
        for block in blocks:
            if not block:
                continue
            if self.re_han.match(block):
                yield from self.cut_hmm(block)
            else:
                for ss in self.re_skip.split(block):
                    if ss:
                        yield ss

    def cut_hmm(self, sentence):
        """Segment *sentence* by decoding a BMES label per character:
        a word spans from a 'B' to the next 'E'; 'S' is a one-char word."""
        prob, pos_list = self.viterbi(sentence, 'BMES')
        begin, nexti = 0, 0
        for i, char in enumerate(sentence):
            pos = pos_list[i]
            if pos == 'B':
                begin = i
            elif pos == 'E':
                yield sentence[begin:i + 1]
                nexti = i + 1
            elif pos == 'S':
                yield char
                nexti = i + 1
        # A trailing 'B'/'M' run never saw its 'E'; emit the remainder.
        if nexti < len(sentence):
            yield sentence[nexti:]

    def viterbi(self, obs, states):
        """Viterbi decoding in log space.

        Returns ``(best_log_prob, best_state_path)`` for the observation
        sequence *obs* over *states*; unseen emissions/transitions fall
        back to MIN_FLOAT. Assumes ``obs`` is non-empty.
        """
        V = [{}]  # V[t][state] = best log-prob of any path ending in state at t
        path = {}
        for y in states:  # initialization with start + emission probabilities
            V[0][y] = self.prob_start[y] + self.prob_emit[y].get(obs[0], self.MIN_FLOAT)
            path[y] = [y]
        for t in range(1, len(obs)):
            V.append({})
            newpath = {}
            for y in states:
                em_p = self.prob_emit[y].get(obs[t], self.MIN_FLOAT)
                (prob, state) = max(
                    [(V[t - 1][y0] + self.prob_trans[y0].get(y, self.MIN_FLOAT) + em_p, y0) for y0 in states])
                V[t][y] = prob
                newpath[y] = path[state] + [y]
            path = newpath

        # A valid segmentation can only end on 'E' (word end) or 'S' (single).
        (prob, state) = max((V[len(obs) - 1][y], y) for y in 'ES')

        return (prob, path[state])

    def get_DAG(self, sentence):
        """Build the word DAG: ``DAG[k]`` lists each index ``i`` such that
        ``sentence[k:i + 1]`` is a dictionary word (falling back to ``[k]``
        so every position has at least the single-character edge)."""
        DAG = {}
        N = len(sentence)
        for k in range(N):
            tmplist = []
            i = k
            frag = sentence[k]
            # Zero-frequency prefix entries keep the scan going while some
            # dictionary word could still continue the fragment.
            while i < N and frag in self.freq:
                if self.freq[frag]:
                    tmplist.append(i)
                i += 1
                frag = sentence[k:i + 1]
            if not tmplist:
                tmplist.append(k)
            DAG[k] = tmplist
        return DAG

    def calc(self, sentence, DAG):
        """Dynamic programming over the DAG, right to left.

        Returns ``route`` where ``route[i] = (best_log_prob, j)`` means the
        best segmentation of ``sentence[i:]`` starts with the word
        ``sentence[i:j + 1]`` under the unigram model.
        """
        n = len(sentence)
        route = {n: (0, 0)}
        log_total = log(self.total)

        for i in range(n - 1, -1, -1):
            cache = []
            for j in DAG[i]:
                # Words with frequency 0 (prefix-only entries) are smoothed
                # to 1 because log(0) is undefined.
                log_p = log(self.freq.get(sentence[i:j + 1], 0) or 1)
                cache.append((log_p - log_total + route[j + 1][0], j))
            route[i] = max(cache)
        return route

    # Backward-compatible alias for the original (misspelled) method name.
    clac = calc


if __name__ == "__main__":
    # Demo: segment a sample sentence and print the tokens.  Guarded so
    # importing this module does not load the assets and run the demo,
    # and the result is printed instead of being discarded.
    sentence1 = "程序员祝海林和朱会震是在孙健的左面和右面, 范凯在最右面。再往左是李松洪"
    tokenizer = ChineseTokenizer()
    print("/".join(tokenizer.cut(sentence1)))